import logging
import random

import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn

from minigpt4.common.registry import registry
from minigpt4.models.blip2 import Blip2Base, disabled_train
from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub

from transformers import LlamaTokenizer, CodeLlamaTokenizer, BitsAndBytesConfig

from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)
import time
import numpy as np

from minigpt4.models import policies


@registry.register_model("mini_gpt4v")
class MiniGPT4v(Blip2Base):
    """
    MiniGPT4-style vision-language model.

    A frozen ViT encodes each image as a grid of 224x224 tiles. Groups of
    ``merge_n**2`` consecutive visual tokens are concatenated along the
    feature dimension and linearly projected into the LLM embedding space.
    The result is interleaved with text embeddings and fed to a 4-bit
    quantized LLaMA / CodeLlama model.
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain_vicuna": "configs/models/minigpt4.yaml",
    }

    # Side length, in pixels, of the square tiles fed to the vision encoder.
    _TILE_SIZE = 224
    # Tokens per tile side produced by the vision encoder after its leading
    # token is dropped. This presumably reflects a 14px patch size
    # (224 / 14 = 16) -- confirm if the ViT changes.
    _TILE_GRID = 16
    # Task tags stripped from instructions when ``remove_template`` is set,
    # applied in this order.
    _TASK_TAGS = (" [caption]", " [vqa]", " [grounding]", " [identify]", " [refer]")

    def __init__(
        self,
        vit_model="eva_clip_g",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        llama_model="",
        prompt_path="",
        prompt_template="",
        max_txt_len=32,
        low_resource=False,  # use 8 bit and put vit in cpu
        end_sym='\n',
        lora_r=8,
        lora_target_modules=("q_proj", "v_proj"),
        lora_alpha=16,
        lora_dropout=0.05,
        ckpt_path="",
        system_prompt=False,
        chat_template=False,
        token_pooling=True,
        use_grad_checkpoint_llm=False,
        max_context_len=3800,
        remove_template=False,
    ):
        """Build the frozen vision encoder, the quantized LLM and the projection.

        Several arguments are accepted for config compatibility but do not
        influence construction here. ``img_size`` and ``freeze_vit`` are
        ignored: the ViT is always built for 224px tiles and always frozen.
        The ``lora_*`` arguments and ``ckpt_path`` are ignored: no LoRA
        adapter is attached, and checkpoints are loaded in ``from_config``.
        ``low_resource``, ``token_pooling`` and ``use_grad_checkpoint_llm``
        are only stored on the instance. The LLM is always loaded in 4-bit
        on CUDA device 0.
        """
        super().__init__()

        self.tokenizer = self.init_tokenizer()
        self.low_resource = low_resource
        self.token_pooling = token_pooling
        self.remove_template = remove_template

        print("token pooling", self.token_pooling)

        self.use_grad_checkpoint_llm = use_grad_checkpoint_llm
        self.max_context_len = max_context_len
        self.chat_template = chat_template

        # ---- vision encoder (always frozen, always built for 224px tiles) ----
        print("vit precision", vit_precision)
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, self._TILE_SIZE, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        for frozen_module in (self.visual_encoder, self.ln_vision):
            for param in frozen_module.parameters():
                param.requires_grad = False
        # Put both modules in eval mode and neutralise .train() so that a later
        # model.train() call cannot switch them back.
        self.visual_encoder = self.visual_encoder.eval()
        self.visual_encoder.train = disabled_train
        self.ln_vision = self.ln_vision.eval()
        self.ln_vision.train = disabled_train
        logging.info("freeze vision encoder")
        print("freeze the vision encoder")

        print('Loading VIT Done')

        # ---- language model ----
        print('Loading LLAMA')

        self.B_SYS, self.E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

        tokenizer_cls = CodeLlamaTokenizer if 'CodeLlama' in llama_model else LlamaTokenizer
        self.llama_tokenizer = tokenizer_cls.from_pretrained(llama_model, use_fast=False)
        # "$$" serves as the padding token; padded target positions are mapped
        # to -100 via pad_token_id in preparing_embedding.
        self.llama_tokenizer.pad_token = "$$"

        self.system_prompt = system_prompt

        # 4-bit NF4 weights with double quantization; matmuls run in bfloat16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )

        self.llama_model = LlamaForCausalLM.from_pretrained(
            llama_model,
            quantization_config=bnb_config,
            device_map={"": 0}  # whole model on CUDA device 0
        )
        self.llama_model = prepare_model_for_kbit_training(self.llama_model)

        print('Loading LLAMA Done')

        # ---- visual -> LLM projection ----
        # merge_n**2 consecutive visual tokens are concatenated before the
        # projection, so its input width is 1408 * merge_n**2. 1408 must match
        # the vision feature width.
        self.merge_n = 3
        self.llama_proj = nn.Linear(
            1408 * self.merge_n ** 2, self.llama_model.config.hidden_size
        )

        self.max_txt_len = max_txt_len
        self.end_sym = end_sym

        self.prompt_list = []
        if prompt_path:
            with open(prompt_path, 'r') as f:
                raw_prompts = f.read().splitlines()
            # Only prompts that contain an image placeholder are usable.
            filtered_prompts = [p for p in raw_prompts if "<ImageHere>" in p]
            self.prompt_list = [prompt_template.format(p) for p in filtered_prompts]
            print('Load {} training prompts'.format(len(self.prompt_list)))
            # random.choice raises on an empty list; only show an example if one exists.
            if self.prompt_list:
                print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))

    def encode_img(self, image):
        """Encode images into LLM-space visual embeddings.

        Args:
            image: tensor shaped (..., C, H, W). Leading dims beyond four are
                flattened into the batch. H and W must be multiples of 224.
        Returns:
            ``(inputs_llama, atts_llama)``. ``inputs_llama`` has shape
            (B, n_tokens // merge_n**2, llm_hidden). ``atts_llama`` is an
            all-ones long mask of shape (B, n_tokens // merge_n**2).
        """
        device = image.device
        if image.dim() > 4:
            image = image.reshape(-1, *image.shape[-3:])

        tile = self._TILE_SIZE
        grid = self._TILE_GRID
        bs, ch, height, width = image.shape
        assert height % tile == 0
        assert width % tile == 0
        rows, cols = height // tile, width // tile

        # Cut every image into a rows x cols grid of tiles -> (bs*rows*cols, ch, tile, tile).
        tiles = image.view(bs, ch, rows, tile, cols, tile).permute(0, 2, 4, 1, 3, 5)
        tiles = tiles.reshape(bs * rows * cols, ch, tile, tile)

        with self.maybe_autocast():
            tile_embeds = self.ln_vision(self.visual_encoder(tiles)).to(device)
            hidden = tile_embeds.shape[-1]

            # Drop each tile's leading token, then stitch the per-tile token
            # grids back into one raster-ordered grid per image.
            tile_embeds = tile_embeds[:, 1:, :].reshape(bs, rows, cols, grid, grid, hidden)
            tile_embeds = tile_embeds.permute(0, 1, 3, 2, 4, 5)  # bs, rows, grid, cols, grid, hidden
            image_embeds = tile_embeds.reshape(bs, rows * grid * cols * grid, hidden)

            # Concatenate every merge_n**2 consecutive tokens along the feature
            # dim. This requires the token count to be divisible by merge_n**2.
            bs, n_tokens, hidden = image_embeds.shape
            group = self.merge_n ** 2
            image_embeds = image_embeds.view(bs, n_tokens // group, hidden * group)

            inputs_llama = self.llama_proj(image_embeds)
            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama

    def get_context_emb(self, prompt, img_list):
        """Embed one prompt, splicing an image embedding at each ``<ImageHere>``.

        Args:
            prompt: text with exactly ``len(img_list)`` ``<ImageHere>`` markers.
            img_list: image embeddings concatenated along dim 1. Callers in
                this class pass tensors shaped (1, n_tokens, hidden).
        Returns:
            The text and image embeddings concatenated along dim 1. Only the
            first text segment is tokenized with special tokens (BOS).
        """
        img_device = img_list[0].device
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(img_device).input_ids  # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]

        seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]

        # text_0, img_0, text_1, img_1, ..., text_n
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]

        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs

    def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
        """Interleave prompts with image embeddings and right-pad the batch.

        Behaviour depends on the inputs:

        * No prompts: ``(img_embeds, atts_img)`` is returned unchanged.
        * ``img_embeds`` is None: the prompts alone are embedded, right-padded
          to the longest one.
        * Otherwise, each ``<ImageHere>`` in ``prompts[i]`` is replaced by the
          next ``pn`` image tokens of ``img_embeds[i]``, where ``pn`` is
          ``img_embeds[i].shape[-2]``. The batch is padded with the pad-token
          embedding and truncated to ``self.max_context_len``.

        Args:
            lengths: optional per-sample image counts. When given,
                ``img_embeds[i]`` is flattened to its token sequence and only
                the first ``lengths[i]`` images' tokens are kept.
        Returns:
            ``(embeds, attention_mask)`` with the mask set to 1 on real tokens.
        """
        if prompts is None or len(prompts) == 0:
            # prompts is not provided, just return the original image embedding
            return img_embeds, atts_img
        elif img_embeds is None:
            # prompt is provided but there is no image embedding. return the prompt embedding in right padding
            self.llama_tokenizer.padding_side = "right"
            prompt_tokens = self.llama_tokenizer(
                prompts,
                return_tensors="pt",
                padding="longest",
                add_special_tokens=False
            ).to(self.device)
            prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
            atts_prompt = prompt_tokens.attention_mask
            return prompt_embeds, atts_prompt

        # return the multi-modal embedding in right padding
        emb_lists = []
        for sample_idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
            pn = each_img_embed.shape[-2]
            if lengths is not None:
                each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
                each_img_embed = each_img_embed[:lengths[sample_idx] * pn]

            p_segs = each_prompt.split('<ImageHere>')
            interleave_emb = []
            for seg_idx, seg in enumerate(p_segs[:-1]):
                p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                p_embed = self.embed_tokens(p_tokens.input_ids)
                img_chunk = each_img_embed[None][:, seg_idx * pn:(seg_idx + 1) * pn]
                interleave_emb.append(torch.cat([p_embed, img_chunk], dim=1))

            wrapped_emb = torch.cat(interleave_emb, dim=1)
            p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
            p_embed = self.embed_tokens(p_tokens.input_ids)
            wrapped_emb = torch.cat([wrapped_emb, p_embed], dim=1)
            emb_lists.append(wrapped_emb)

        emb_lens = [emb.shape[1] for emb in emb_lists]
        pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))

        # Pad to the longest sample, but never beyond the context budget.
        max_length = min(max(emb_lens), self.max_context_len)
        wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
        wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)

        for i, emb in enumerate(emb_lists):
            length = min(emb_lens[i], self.max_context_len)
            wrapped_embs[i, :length] = emb[:, :length]
            wrapped_atts[i, :length] = 1

        return wrapped_embs, wrapped_atts

    def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
        """
        Concatenate the batched input embedding and batched output embedding together.
        Both the input and the output embedding should be right padded.

        For each sample the output is inserted right after the input's real
        tokens: ``[input tokens, output (incl. its padding), input padding]``.

        Returns:
            ``(cat_embs, cat_atts, input_lens)``, where ``input_lens[i]`` is
            the number of real input tokens of sample ``i``, taken as the sum
            of its attention mask.
        """
        input_lens = []
        cat_embs = []
        cat_atts = []

        for i in range(input_embs.size(0)):
            input_len = input_atts[i].sum()
            input_lens.append(input_len)

            cat_embs.append(
                torch.cat([
                    input_embs[i][:input_len],
                    output_embs[i],
                    input_embs[i][input_len:]
                ])
            )
            cat_atts.append(
                torch.cat([
                    input_atts[i][:input_len],
                    output_atts[i],
                    input_atts[i][input_len:]
                ])
            )

        cat_embs = torch.stack(cat_embs)
        cat_atts = torch.stack(cat_atts)
        return cat_embs, cat_atts, input_lens

    def get_conv_emb(self, conv_q, conv_a, conv_img):
        """Concatenate conversation turns so only the answers are regressed.

        Args:
            conv_q: per-sample lists of question strings.
            conv_a: per-sample lists of answer strings, aligned with ``conv_q``.
            conv_img: per-sample lists of image embeddings (or None) per
                question, as produced by ``assign_imgs``.
        Returns:
            ``(regress_embeds, regress_attn, targets)``, right-padded and
            truncated to ``self.max_txt_len``. Question positions in
            ``targets`` are -100.
        """
        regress_embs_list = []
        targets_list = []

        batch_size = len(conv_q)
        for batch_idx in range(batch_size):
            questions, answers = conv_q[batch_idx], conv_a[batch_idx]
            assigned_imgs = conv_img[batch_idx]
            questions = [self.prompt_wrap(
                img_embeds=img,
                atts_img=None,
                prompts=[q],
                lengths=[img.shape[1]] if img is not None else None) for q, img in zip(questions, assigned_imgs)]
            q_embs = [emb for emb, _ in questions]

            # NOTE(review): unlike the single-turn path in preparing_embedding,
            # answers are not suffixed with self.end_sym here -- confirm the
            # conversation datasets already append an end-of-turn marker.
            answers = [self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device) for a in answers]
            cur_emb = []
            cur_target = []
            for i in range(len(questions)):
                cur_emb.append(q_embs[i])
                # Questions are context only: never contribute to the loss.
                cur_target.append(torch.ones_like(q_embs[i][..., 0], dtype=torch.int) * -100)

                cur_emb.append(self.embed_tokens(answers[i].input_ids))
                cur_target.append(answers[i].input_ids)

            cur_emb = torch.cat(cur_emb, dim=1)
            cur_target = torch.cat(cur_target, dim=1)

            regress_embs_list.append(cur_emb)
            targets_list.append(cur_target)

        max_len = min(max(target.shape[1] for target in targets_list), self.max_txt_len)

        regress_embeds = torch.zeros([batch_size, max_len, cur_emb.shape[-1]], device=self.device)
        regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device)
        targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100

        for batch_idx in range(batch_size):
            cur_len = regress_embs_list[batch_idx].shape[1]
            regress_embeds[batch_idx, :cur_len] = regress_embs_list[batch_idx][0, :max_len]
            regress_attn[batch_idx, :cur_len] = 1
            targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]

        return regress_embeds, regress_attn, targets

    def preparing_embedding(self, samples):
        """Build the conditioning and regression embeddings for a batch.

        Conversation samples (``'conv_q'`` present) are handled by
        ``get_conv_emb``, with an empty conditioning part. Otherwise the
        instruction is wrapped around the image embeddings, and the answer
        plus ``end_sym`` becomes the regression target. Padding positions in
        the target are set to -100.

        Returns:
            ``(cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets)``
        """
        def remove_special_tokens(data):
            # Strip task tags such as " [vqa]" from every instruction.
            cleaned = []
            for instruct in data:
                for tag in self._TASK_TAGS:
                    instruct = instruct.replace(tag, "")
                cleaned.append(instruct)
            return cleaned

        ### prepare input tokens
        if 'image' in samples:
            img_embeds, img_atts = self.encode_img(samples["image"])
        else:
            img_embeds = img_atts = None

        if 'conv_q' in samples:
            # handling conversation datasets
            conv_q, conv_a = samples['conv_q'], samples['conv_a']

            connect_sym = samples['connect_sym'][0]
            conv_q = [q.split(connect_sym) for q in conv_q]
            conv_a = [a.split(connect_sym) for a in conv_a]
            conv_img = assign_imgs(conv_q, img_embeds)

            if self.chat_template:
                conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]

            regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img)
            # Everything is in the regression part; the condition is empty.
            cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0]

        else:
            instruction = samples["instruction_input"] if "instruction_input" in samples else None

            # Both transforms iterate over the instructions, so skip them when
            # none were supplied (prompt_wrap handles a None prompt list).
            if self.remove_template and instruction is not None:
                instruction = remove_special_tokens(instruction)

            if self.chat_template and instruction is not None:
                instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]

            if 'length' in samples:
                # the input is an image sequence (like videos)
                bsz, pn, hs = img_embeds.shape
                img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)
                cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
            else:
                cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)

            ### prepare target tokens
            self.llama_tokenizer.padding_side = "right"
            text = [t + self.end_sym for t in samples["answer"]]

            regress_tokens = self.llama_tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                truncation=True,
                max_length=self.max_txt_len,
                add_special_tokens=False
            ).to(self.device)

            regress_token_ids = regress_tokens.input_ids
            regress_atts = regress_tokens.attention_mask
            part_targets = regress_token_ids.masked_fill(
                regress_token_ids == self.llama_tokenizer.pad_token_id, -100
            )

            regress_embeds = self.embed_tokens(regress_token_ids)

        return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets

    def forward(self, samples, reduction="mean"):
        """Compute the language-modeling loss for a batch.

        Args:
            samples: batch dict consumed by ``preparing_embedding``.
            reduction: forwarded to the LLM's loss (``"none"`` is used by
                ``multi_select`` for per-sample losses).
        Returns:
            ``{"loss": loss}``
        """
        # prepare the embedding to condition and the embedding to regress
        cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
            self.preparing_embedding(samples)

        # concat the embedding to condition and the embedding to regress
        inputs_embeds, attention_mask, input_lens = \
            self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)

        # get bos token embedding
        bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        bos_atts = attention_mask[:, :1]

        # add bos token at the beginning
        inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
        attention_mask = torch.cat([bos_atts, attention_mask], dim=1)

        # ensemble the final targets: -100 everywhere except the answer span,
        # which starts right after BOS + the real input tokens.
        targets = torch.full(inputs_embeds.shape[:2], -100, dtype=torch.long, device=self.device)
        for i, target in enumerate(part_targets):
            targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target  # plus 1 for bos

        with self.maybe_autocast():
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
                reduction=reduction
            )
        loss = outputs.loss

        return {"loss": loss}

    @torch.no_grad()
    def generate(
        self,
        images,
        texts,
        use_nucleus_sampling=False,
        num_beams=1,
        max_new_tokens=20,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1,
        length_penalty=1,
        temperature=1,
        do_sample=False,
        stop_words_ids=(2,),
        lengths=None,
    ):
        '''
            Generate answers for a batch of image/prompt pairs (used at test time).

            Prompts are left-padded with zero embeddings before calling the LLM.
            Only ``num_beams``, ``max_new_tokens`` and ``do_sample`` are passed
            to the LLM. ``use_nucleus_sampling``, ``min_length``, ``top_p``,
            ``repetition_penalty``, ``length_penalty``, ``temperature`` and
            ``stop_words_ids`` are accepted for API compatibility but are
            currently ignored.

            Args:
                images: image batch passed to ``encode_img``.
                texts: one prompt per sample, each with ``<ImageHere>`` markers.
                lengths: optional per-sample number of images (e.g. video
                    frames); when given, images are regrouped per sample.
            Returns:
                list of decoded answer strings.
        '''
        img_embeds, atts_img = self.encode_img(images.to(self.device))
        if lengths is not None:
            image_lists = []
            img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1])
            for idx, img_embed in enumerate(img_embeds):
                image_lists.append([img_embed[i][None] for i in range(lengths[idx])])
        else:
            image_lists = [[image_emb[None]] for image_emb in img_embeds]
        assert len(texts) == len(image_lists)
        batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]

        batch_size = len(batch_embs)
        max_len = max(emb.shape[1] for emb in batch_embs)
        emb_dim = batch_embs[0].shape[2]
        dtype = batch_embs[0].dtype
        device = batch_embs[0].device

        # Left padding, so every prompt ends right where generation starts.
        embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
        attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
        for i, emb in enumerate(batch_embs):
            emb_len = emb.shape[1]
            embs[i, -emb_len:] = emb[0]
            attn_mask[i, -emb_len:] = 1

        with self.maybe_autocast():
            outputs = self.llama_model.generate(
                inputs_embeds=embs,
                attention_mask=attn_mask,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                do_sample=do_sample,
            )

        answers = []
        for output_token in outputs:
            if output_token[0] == 0:
                output_token = output_token[1:]
            output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
            output_texts = output_texts.split('</s>')[0]  # remove the stop sign </s>
            output_texts = output_texts.replace("<s>", "")
            output_texts = output_texts.split(r'[/INST]')[-1].strip()
            answers.append(output_texts)

        return answers

    @torch.no_grad()
    def multi_select(self, images, texts, answers, num_cand=None):
        """Rank candidate answers by per-sample LM loss (lower is better).

        Args:
            images, texts: the shared batch inputs.
            answers: list of candidate batches; ``answers[k][i]`` is the k-th
                candidate for sample ``i``.
            num_cand: optional per-sample number of valid candidates; the
                remaining slots are pushed to the end of the ranking.
        Returns:
            Nested list of candidate indices per sample, sorted by loss.
        """
        all_losses = []
        for answer in answers:
            choice_samples = {
                'image': images,
                'instruction_input': texts,
                'answer': answer
            }
            loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
            all_losses.append(loss)
            torch.cuda.empty_cache()
        all_losses = torch.cat(all_losses, dim=-1)
        if num_cand is not None:
            for i in range(all_losses.shape[0]):
                # Invalid (padding) candidates get a huge loss.
                all_losses[i, num_cand[i]:] = 9999
        output_class_ranks = torch.argsort(all_losses, dim=-1)
        return output_class_ranks.tolist()

    def predict_answers(
        self,
        samples,
        num_beams=5,
        inference_method="generate",
        max_len=10,
        min_len=1,
        num_ans_candidates=128,
        answer_list=None,
        prompt="",
        length_penalty=0,
        **kwargs
    ):
        '''
            Open-ended VQA: generate a free-form answer per sample.

            ``inference_method``, ``num_ans_candidates``, ``answer_list`` and
            ``prompt`` are not used here. Answers are lemmatized when
            ``samples["apply_lemmatizer"]`` is truthy.
        '''
        images = samples["image"].cuda()
        texts = samples["instruction_input"]

        output_text = self.generate(
            images=images,
            texts=texts,
            num_beams=num_beams,
            max_new_tokens=max_len,
            min_length=min_len,
            length_penalty=length_penalty
        )

        if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
            output_text = self._lemmatize(output_text)

        return output_text

    def predict_class(
            self,
            samples,
            num_beams=5,
            inference_method="generate",
            max_len=10,
            min_len=1,
            num_ans_candidates=5,
            answer_list=None,
            prompt="",
            length_penalty=0,
            **kwargs
    ):
        '''
            Multiple-choice VQA: pick the candidate with the lowest LM loss.

            Reads ``samples["choices"]``, where ``choices[k][i]`` is the k-th
            candidate for sample ``i``, and ``samples["num_choices"]``. Returns
            the chosen answer string per sample. The generation arguments are
            not used.
        '''
        image = samples["image"].cuda()
        instruction = samples['instruction_input']
        answers = samples["choices"]
        num_cand = samples["num_choices"]

        ranks = self.multi_select(image, instruction, answers, num_cand)

        pred_ans = []
        for i, rank in enumerate(ranks):
            pred = answers[rank[0]][i]
            pred_ans.append(pred)
        return pred_ans

    def embed_tokens(self, token_ids):
        """Look up LLM input embeddings for ``token_ids``.

        Tries the PEFT-wrapped attribute path first
        (``base_model.model.model``); if that raises AttributeError, falls
        back to the plain model's ``model.embed_tokens``.
        """
        try:
            embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
        except AttributeError:
            embeds = self.llama_model.model.embed_tokens(token_ids)

        return embeds

    @classmethod
    def from_config(cls, cfg):
        """Build the model from a config mapping and optionally load weights.

        If ``cfg["ckpt"]`` is set, its ``'model'`` state dict is loaded with
        ``strict=False``. The resulting missing/unexpected keys are logged.
        """
        vit_model = cfg.get("vit_model", "eva_clip_g")
        img_size = cfg.get("image_size")
        llama_model = cfg.get("llama_model")

        drop_path_rate = cfg.get("drop_path_rate", 0)
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)
        low_resource = cfg.get("low_resource", False)

        prompt_path = cfg.get("prompt_path", "")
        prompt_template = cfg.get("prompt_template", "")
        max_txt_len = cfg.get("max_txt_len", 300)
        end_sym = cfg.get("end_sym", '\n')

        lora_r = cfg.get("lora_r", 64)
        lora_alpha = cfg.get("lora_alpha", 16)
        chat_template = cfg.get("chat_template", False)
        system_prompt = cfg.get("system_prompt", False)
        token_pooling = cfg.get("token_pooling", True)

        use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
        max_context_len = cfg.get("max_context_len", 3800)
        remove_template = cfg.get("remove_template", False)

        model = cls(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            llama_model=llama_model,
            prompt_path=prompt_path,
            prompt_template=prompt_template,
            max_txt_len=max_txt_len,
            low_resource=low_resource,
            end_sym=end_sym,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            chat_template=chat_template,
            system_prompt=system_prompt,
            token_pooling=token_pooling,
            use_grad_checkpoint_llm=use_grad_checkpoint_llm,
            max_context_len=max_context_len,
            remove_template=remove_template
        )

        ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
        if ckpt_path:
            print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            ckpt = torch.load(ckpt_path, map_location="cpu")
            msg = model.load_state_dict(ckpt['model'], strict=False)
            logging.info("checkpoint load result: %s", msg)

        return model


def assign_imgs(batched_instruct_list, batched_img_embeds):
    '''Assign image embeddings to the segments of interleaved instructions.

    Interleaved data is split into segments. Each segment receives as many
    consecutive images as it has ``<ImageHere>`` placeholders, consumed in
    order from that sample's images.

    Args:
        batched_instruct_list: per-sample lists of instruction segments.
        batched_img_embeds: image embeddings, either (B, n_tokens, hidden)
            for one image per sample, or (B, n_images, n_tokens, hidden).
            May be None for text-only batches.
    Returns:
        Per-sample lists aligned with the segments. Each entry is a tensor of
        shape (1, n_img, n_tokens, hidden), or None when the segment contains
        no placeholder.
    Raises:
        ValueError: if ``batched_img_embeds`` is None but a segment contains
            an ``<ImageHere>`` placeholder.
    '''
    if batched_img_embeds is None:
        # Text-only batch: every segment must be image-free.
        if any('<ImageHere>' in instruct
               for instruct_list in batched_instruct_list
               for instruct in instruct_list):
            raise ValueError("Found '<ImageHere>' placeholders but no image embeddings were provided.")
        return [[None] * len(instruct_list) for instruct_list in batched_instruct_list]

    if len(batched_img_embeds.shape) == 3:
        # One image per sample: add the image axis.
        batched_img_embeds = batched_img_embeds[:, None]

    batched_assigned = []

    for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds):
        img_idx = 0
        assigned_img = []
        for instruct in instruct_list:
            n_img = instruct.count('<ImageHere>')
            if n_img > 0:  # this instruction includes images.
                assigned_img.append(img_embeds[None, img_idx:img_idx + n_img])
                img_idx += n_img
            else:  # this instruction doesn't include images
                assigned_img.append(None)
        batched_assigned.append(assigned_img)

    return batched_assigned